{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "64189f24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/加载编码工具\n",
    "from transformers import BertTokenizer\n",
    "\n",
    "token = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "\n",
    "token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7b6bbdad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids torch.Size([2, 18])\n",
      "token_type_ids torch.Size([2, 18])\n",
      "length torch.Size([2])\n",
      "attention_mask torch.Size([2, 18])\n",
      "[CLS] 轻 轻 的 我 走 了 ， 正 如 我 轻 轻 地 来 。 [SEP] [PAD]\n"
     ]
    }
   ],
   "source": [
    "#第8章/试编码句子\n",
    "out = token.batch_encode_plus(\n",
    "    batch_text_or_text_pairs=['轻轻的我走了，正如我轻轻地来。', '我轻轻的招手，作别西天的云彩。'],\n",
    "    truncation=True,\n",
    "    padding='max_length',\n",
    "    max_length=18,\n",
    "    return_tensors='pt',\n",
    "    return_length=True)\n",
    "\n",
    "#查看编码输出\n",
    "for k, v in out.items():\n",
    "    print(k, v.shape)\n",
    "\n",
    "#把编码还原为句子\n",
    "print(token.decode(out['input_ids'][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "dfab7774",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 9600\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 0\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 1200\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/加载数据集\n",
    "from datasets import load_from_disk\n",
    "\n",
    "dataset = load_from_disk('./data/ChnSentiCorp')\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a31d262e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-d3283c8583b24ed8.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-27dc37d5bd6706e4.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-826f6ca6e106aacc.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-835f419ddabc745f.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-682897291cf43603.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-9d6e5e25ce6b17d4.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-a0072e02f56e2af6.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-f0333c24f02eb083.arrow\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'token_type_ids', 'length', 'attention_mask'],\n",
       "        num_rows: 9600\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: [],\n",
       "        num_rows: 0\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids', 'token_type_ids', 'length', 'attention_mask'],\n",
       "        num_rows: 1200\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/编码数据,同时删除多余的字段\n",
    "def f(data):\n",
    "    return token.batch_encode_plus(batch_text_or_text_pairs=data['text'],\n",
    "                                   truncation=True,\n",
    "                                   padding='max_length',\n",
    "                                   max_length=30,\n",
    "                                   return_length=True)\n",
    "\n",
    "\n",
    "dataset = dataset.map(function=f,\n",
    "                      batched=True,\n",
    "                      batch_size=1000,\n",
    "                      num_proc=4,\n",
    "                      remove_columns=['text', 'label'])\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6e44c172",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-431e5ba849eb7708_00000_of_00004.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-431e5ba849eb7708_00001_of_00004.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-431e5ba849eb7708_00002_of_00004.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-431e5ba849eb7708_00003_of_00004.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-f5af1b4bd662e257_00000_of_00004.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-f5af1b4bd662e257_00001_of_00004.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-f5af1b4bd662e257_00002_of_00004.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-f5af1b4bd662e257_00003_of_00004.arrow\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids', 'token_type_ids', 'length', 'attention_mask'],\n",
       "        num_rows: 9286\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: [],\n",
       "        num_rows: 0\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids', 'token_type_ids', 'length', 'attention_mask'],\n",
       "        num_rows: 1157\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/过滤掉太短的句子\n",
    "def f(data):\n",
    "    return [i >= 30 for i in data['length']]\n",
    "\n",
    "\n",
    "dataset = dataset.filter(function=f, batched=True, batch_size=1000, num_proc=4)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "02e7d76d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/定义计算设备\n",
    "import torch\n",
    "\n",
    "device = 'cpu'\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "0faddebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#第8章/数据整理函数\n",
    "def collate_fn(data):\n",
    "    #取出编码结果\n",
    "    input_ids = [i['input_ids'] for i in data]\n",
    "    attention_mask = [i['attention_mask'] for i in data]\n",
    "    token_type_ids = [i['token_type_ids'] for i in data]\n",
    "\n",
    "    #转换为tensor格式\n",
    "    input_ids = torch.LongTensor(input_ids)\n",
    "    attention_mask = torch.LongTensor(attention_mask)\n",
    "    token_type_ids = torch.LongTensor(token_type_ids)\n",
    "\n",
    "    #把第15个词替换为mask\n",
    "    labels = input_ids[:, 15].reshape(-1).clone()\n",
    "    input_ids[:, 15] = token.get_vocab()[token.mask_token]\n",
    "\n",
    "    #移动到计算设备\n",
    "    input_ids = input_ids.to(device)\n",
    "    attention_mask = attention_mask.to(device)\n",
    "    token_type_ids = token_type_ids.to(device)\n",
    "    labels = labels.to(device)\n",
    "\n",
    "    return input_ids, attention_mask, token_type_ids, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "f985a056",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] 我 是 法 语 初 学 者, 学 了 78 个 课 时 [MASK] 初 级 班. 因 为 我 是 老 年 人, 法 [SEP]\n",
      "的\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([2, 30]),\n",
       " torch.Size([2, 30]),\n",
       " torch.Size([2, 30]),\n",
       " tensor([4638, 2692], device='cuda:0'))"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/数据整理函数试算\n",
    "#模拟一批数据\n",
    "data = [{\n",
    "    'input_ids': [\n",
    "        101, 2769, 3221, 3791, 6427, 1159, 2110, 5442, 117, 2110, 749, 8409,\n",
    "        702, 6440, 3198, 4638, 1159, 5277, 4408, 119, 1728, 711, 2769, 3221,\n",
    "        5439, 2399, 782, 117, 3791, 102\n",
    "    ],\n",
    "    'token_type_ids': [0] * 30,\n",
    "    'attention_mask': [1] * 30\n",
    "}, {\n",
    "    'input_ids': [\n",
    "        101, 679, 7231, 8024, 2376, 3301, 1351, 6848, 4638, 8024, 3301, 1351,\n",
    "        3683, 6772, 4007, 2692, 8024, 2218, 3221, 100, 2970, 1366, 2208, 749,\n",
    "        8024, 5445, 684, 1059, 3221, 102\n",
    "    ],\n",
    "    'token_type_ids': [0] * 30,\n",
    "    'attention_mask': [1] * 30\n",
    "}]\n",
    "\n",
    "#试算\n",
    "input_ids, attention_mask, token_type_ids, labels = collate_fn(data)\n",
    "\n",
    "#把编码还原为句子\n",
    "print(token.decode(input_ids[0]))\n",
    "print(token.decode(labels[0]))\n",
    "\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "929eca83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "580"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/定义数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset['train'],\n",
    "                                     batch_size=16,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "ca4c4175",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] 位 于 友 谊 路 金 融 街 ， 找 不 到 吃 饭 [MASK] 地 方 。 酒 店 刚 刚 装 修 好 ， 有 点 [SEP]\n",
      "的\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 30]),\n",
       " torch.Size([16, 30]),\n",
       " torch.Size([16, 30]),\n",
       " tensor([4638, 6230,  511, 7313, 3221, 7315, 6820, 6858, 7564, 3211, 1690, 3315,\n",
       "         3300,  172, 6821, 1126], device='cuda:0'))"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/查看数据样例\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    break\n",
    "\n",
    "print(token.decode(input_ids[0]))\n",
    "print(token.decode(labels[0]))\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "66291b57",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "10226.7648"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/加载预训练模型\n",
    "from transformers import BertModel\n",
    "\n",
    "pretrained = BertModel.from_pretrained('bert-base-chinese')\n",
    "\n",
    "#统计参数量\n",
    "sum(i.numel() for i in pretrained.parameters()) / 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "1dbea6bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#第8章/不训练预训练模型,不需要计算梯度\n",
    "for param in pretrained.parameters():\n",
    "    param.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "dddf4825",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 30, 768])"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/预训练模型试算\n",
    "#设定计算设备\n",
    "pretrained.to(device)\n",
    "\n",
    "#模型试算\n",
    "out = pretrained(input_ids=input_ids,\n",
    "                 attention_mask=attention_mask,\n",
    "                 token_type_ids=token_type_ids)\n",
    "\n",
    "out.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "8c1379c2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 21128])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第8章/定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.decoder = torch.nn.Linear(in_features=768,\n",
    "                                       out_features=token.vocab_size,\n",
    "                                       bias=False)\n",
    "        #重新初始化decode中的bias参数为全0\n",
    "        self.bias = torch.nn.Parameter(data=torch.zeros(token.vocab_size))\n",
    "        self.decoder.bias = self.bias\n",
    "\n",
    "        #定义Dropout层，防止过拟合\n",
    "        self.dropout = torch.nn.Dropout(p=0.5)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        #使用预训练模型抽取数据特征\n",
    "        with torch.no_grad():\n",
    "            out = pretrained(input_ids=input_ids,\n",
    "                             attention_mask=attention_mask,\n",
    "                             token_type_ids=token_type_ids)\n",
    "\n",
    "        #把第15个词的特征，投影到全字典范围内\n",
    "        out = self.dropout(out.last_hidden_state[:, 15])\n",
    "        out = self.decoder(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "#设定计算设备\n",
    "model.to(device)\n",
    "\n",
    "#试算\n",
    "model(input_ids=input_ids,\n",
    "      attention_mask=attention_mask,\n",
    "      token_type_ids=token_type_ids).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "a981b886",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 10.022448539733887 0.0004998275862068966 0.0\n",
      "0 50 8.73752498626709 0.0004912068965517241 0.1875\n",
      "0 100 7.15378475189209 0.0004825862068965518 0.25\n",
      "0 150 6.036799907684326 0.0004739655172413793 0.25\n",
      "0 200 6.4759111404418945 0.0004653448275862069 0.0625\n",
      "0 250 3.800313949584961 0.00045672413793103453 0.4375\n",
      "0 300 7.0236616134643555 0.00044810344827586206 0.25\n",
      "0 350 5.194925785064697 0.00043948275862068964 0.3125\n",
      "0 400 5.884705543518066 0.0004308620689655173 0.3125\n",
      "0 450 4.168199062347412 0.00042224137931034486 0.4375\n",
      "0 500 6.240730285644531 0.0004136206896551724 0.375\n",
      "0 550 4.36335563659668 0.00040500000000000003 0.375\n",
      "1 0 3.574946403503418 0.00039982758620689656 0.375\n",
      "1 50 4.219260215759277 0.00039120689655172415 0.375\n",
      "1 100 3.1496996879577637 0.00038258620689655173 0.625\n",
      "1 150 3.0767054557800293 0.0003739655172413793 0.375\n",
      "1 200 3.6137592792510986 0.0003653448275862069 0.5625\n",
      "1 250 3.3886961936950684 0.0003567241379310345 0.5\n",
      "1 300 5.3483662605285645 0.0003481034482758621 0.4375\n",
      "1 350 2.7506301403045654 0.00033948275862068965 0.625\n",
      "1 400 3.5999977588653564 0.00033086206896551723 0.5625\n",
      "1 450 2.456437349319458 0.0003222413793103448 0.6875\n",
      "1 500 2.7866811752319336 0.00031362068965517246 0.5625\n",
      "1 550 3.411165714263916 0.000305 0.5625\n",
      "2 0 3.2547690868377686 0.0002998275862068965 0.5625\n",
      "2 50 2.0145423412323 0.00029120689655172415 0.75\n",
      "2 100 2.3726096153259277 0.00028258620689655174 0.5625\n",
      "2 150 1.8401336669921875 0.0002739655172413793 0.75\n",
      "2 200 3.0410377979278564 0.0002653448275862069 0.4375\n",
      "2 250 2.980194568634033 0.0002567241379310345 0.3125\n",
      "2 300 2.7839863300323486 0.0002481034482758621 0.375\n",
      "2 350 3.1279046535491943 0.00023948275862068966 0.4375\n",
      "2 400 3.3245158195495605 0.00023086206896551727 0.5625\n",
      "2 450 3.731590747833252 0.00022224137931034483 0.5\n",
      "2 500 2.3065927028656006 0.0002136206896551724 0.6875\n",
      "2 550 3.200794219970703 0.000205 0.375\n",
      "3 0 4.259110927581787 0.00019982758620689655 0.4375\n",
      "3 50 2.6592702865600586 0.00019120689655172414 0.75\n",
      "3 100 2.2059292793273926 0.00018258620689655172 0.75\n",
      "3 150 2.5569658279418945 0.00017396551724137933 0.6875\n",
      "3 200 1.9693695306777954 0.0001653448275862069 0.875\n",
      "3 250 1.3077285289764404 0.0001567241379310345 0.9375\n",
      "3 300 1.9754962921142578 0.00014810344827586208 0.6875\n",
      "3 350 2.6310343742370605 0.00013948275862068964 0.5\n",
      "3 400 2.6864442825317383 0.00013086206896551725 0.75\n",
      "3 450 2.837418556213379 0.00012224137931034484 0.625\n",
      "3 500 2.5199856758117676 0.00011362068965517242 0.75\n",
      "3 550 2.2130799293518066 0.000105 0.6875\n",
      "4 0 3.369119644165039 9.982758620689655e-05 0.625\n",
      "4 50 2.5900590419769287 9.120689655172413e-05 0.4375\n",
      "4 100 2.212355613708496 8.258620689655173e-05 0.6875\n",
      "4 150 3.9292068481445312 7.396551724137931e-05 0.4375\n",
      "4 200 1.7726680040359497 6.53448275862069e-05 0.75\n",
      "4 250 2.4024291038513184 5.672413793103448e-05 0.5625\n",
      "4 300 2.8472514152526855 4.810344827586207e-05 0.625\n",
      "4 350 2.0372190475463867 3.9482758620689656e-05 0.8125\n",
      "4 400 2.575105667114258 3.086206896551724e-05 0.625\n",
      "4 450 1.9375989437103271 2.2241379310344828e-05 0.75\n",
      "4 500 2.0469861030578613 1.3620689655172415e-05 0.6875\n",
      "4 550 2.005430221557617 5e-06 0.8125\n"
     ]
    }
   ],
   "source": [
    "#第8章/训练\n",
    "from transformers import AdamW\n",
    "from transformers.optimization import get_scheduler\n",
    "\n",
    "\n",
    "def train():\n",
    "    #定义优化器\n",
    "    optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1.0)\n",
    "\n",
    "    #定义loss函数\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    #定义学习率调节器\n",
    "    scheduler = get_scheduler(name='linear',\n",
    "                              num_warmup_steps=0,\n",
    "                              num_training_steps=len(loader) * 5,\n",
    "                              optimizer=optimizer)\n",
    "\n",
    "    #模型切换到训练模式\n",
    "    model.train()\n",
    "\n",
    "    #共训练5个epoch\n",
    "    for epoch in range(5):\n",
    "        #按批次遍历训练集中的数据\n",
    "        for i, (input_ids, attention_mask, token_type_ids,\n",
    "                labels) in enumerate(loader):\n",
    "\n",
    "            #模型计算\n",
    "            out = model(input_ids=input_ids,\n",
    "                        attention_mask=attention_mask,\n",
    "                        token_type_ids=token_type_ids)\n",
    "\n",
    "            #计算loss并使用梯度下降法优化模型参数\n",
    "            loss = criterion(out, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            #输出各项数据的情况，便于观察\n",
    "            if i % 50 == 0:\n",
    "                out = out.argmax(dim=1)\n",
    "                accuracy = (out == labels).sum().item() / len(labels)\n",
    "                lr = optimizer.state_dict()['param_groups'][0]['lr']\n",
    "                print(epoch, i, loss.item(), lr, accuracy)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "4dcaad80",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "0.5645833333333333\n"
     ]
    }
   ],
   "source": [
    "#第8章/测试\n",
    "def test():\n",
    "    #定义测试数据集加载器\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    #下游任务模型切换到运行模式\n",
    "    model.eval()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    #按批次遍历测试集中的数据\n",
    "    for i, (input_ids, attention_mask, token_type_ids,\n",
    "            labels) in enumerate(loader_test):\n",
    "\n",
    "        #计算15个批次即可，不需要全部遍历\n",
    "        if i == 15:\n",
    "            break\n",
    "\n",
    "        print(i)\n",
    "\n",
    "        #计算\n",
    "        with torch.no_grad():\n",
    "            out = model(input_ids=input_ids,\n",
    "                        attention_mask=attention_mask,\n",
    "                        token_type_ids=token_type_ids)\n",
    "\n",
    "        #统计正确率\n",
    "        out = out.argmax(dim=1)\n",
    "        correct += (out == labels).sum().item()\n",
    "        total += len(labels)\n",
    "\n",
    "    print(correct / total)\n",
    "\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
